import torch

def NMSE(x_hat: torch.Tensor, x: torch.Tensor, *, eps: float = 0.0) -> torch.Tensor:
    """Return the batch-averaged normalized mean squared error.

    Both tensors are flattened to ``(batch, -1)`` using their leading
    dimension as the batch size. For every sample the squared error
    ``sum(|x - x_hat|**2)`` is divided by the reference power
    ``sum(|x|**2)``; the per-sample ratios are then averaged.

    Args:
        x_hat: Estimated tensor; the first dimension indexes samples.
        x: Reference tensor. After flattening it must have the same
            ``(batch, features)`` shape as ``x_hat``.
        eps: Optional non-negative constant added to the per-sample
            reference power before dividing. The default ``0.0`` leaves
            the result untouched, which means a sample with zero power
            yields ``nan``/``inf``; pass a small positive value to keep
            the result finite in that case.

    Returns:
        A 0-dim tensor holding the mean of the per-sample ratios (a linear
        ratio, not dB).

    Raises:
        ValueError: If the flattened inputs differ in shape. Without this
            check the subtraction would broadcast silently and return a
            meaningless value.
    """
    # Collapse every non-batch dimension so each row is one sample.
    x = torch.reshape(x, (len(x), -1))
    x_hat = torch.reshape(x_hat, (len(x_hat), -1))

    if x_hat.shape != x.shape:
        raise ValueError(
            "NMSE expects x_hat and x to have matching flattened shapes, "
            f"got {tuple(x_hat.shape)} and {tuple(x.shape)}"
        )

    # abs() before squaring keeps the result real for complex inputs.
    power = torch.sum(torch.abs(x) ** 2, dim=1)
    mse = torch.sum(torch.abs(x - x_hat) ** 2, dim=1)

    # Only touch the denominator when requested so that default results
    # are identical to the unregularized ratio.
    if eps:
        power = power + eps

    return torch.mean(mse / power)